{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Maximizing NLP model performance using automatic model tuning in Amazon SageMaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to fine tune natural language processing (NLP) models in Amazon SageMaker and do automatic model tunning using hyperparameter optimization. We use the Hugging Face's [pytorch-transformers](https://github.com/huggingface/pytorch-transformers) as example code and library to build and train models.\n",
    "\n",
    "There are two datasets to be used in this demo. One is the MRPC data for the General Language Understanding Evaluation ([GLUE](https://gluebenchmark.com/tasks/)) task, and the other is [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) 1.1 data for questions and answering.\n",
    "\n",
    "More Amazon SageMaker hyperparameter tunning notebook examples can be found [here](https://github.com/awslabs/amazon-sagemaker-examples/tree/master/hyperparameter_tuning)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From this blog post:  https://aws.amazon.com/blogs/machine-learning/maximizing-nlp-model-performance-with-automatic-model-tuning-in-amazon-sagemaker/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and training script preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download data and code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GLUE data can be download by using this [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and extracting CoLA...\n",
      "\tCompleted!\n",
      "Downloading and extracting SST...\n",
      "\tCompleted!\n",
      "Processing MRPC...\n",
      "Local MRPC data not specified, downloading data from https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt\n",
      "\tCompleted!\n",
      "Downloading and extracting QQP...\n",
      "\tCompleted!\n",
      "Downloading and extracting STS...\n",
      "\tCompleted!\n",
      "Downloading and extracting MNLI...\n",
      "\tCompleted!\n",
      "Downloading and extracting SNLI...\n",
      "\tCompleted!\n",
      "Downloading and extracting QNLI...\n",
      "\tCompleted!\n",
      "Downloading and extracting RTE...\n",
      "\tCompleted!\n",
      "Downloading and extracting WNLI...\n",
      "\tCompleted!\n",
      "Downloading and extracting diagnostic...\n",
      "\tCompleted!\n"
     ]
    }
   ],
   "source": [
    "# Download all GLUE data to a local folder\n",
    "\n",
    "!python download_glue_data.py --data_dir glue_data --tasks all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training scripts can be download with git cloning [pytorch-transformers](https://github.com/huggingface/pytorch-transformers). The `examples` folder has training script `run_glue.py` for GLUE data and  `run_squad.py` for SQuAD data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'pytorch-transformers'...\n",
      "remote: Enumerating objects: 143, done.\u001b[K\n",
      "remote: Counting objects: 100% (143/143), done.\u001b[K\n",
      "remote: Compressing objects: 100% (54/54), done.\u001b[K\n",
      "remote: Total 22228 (delta 70), reused 122 (delta 62), pack-reused 22085\u001b[K\n",
      "Receiving objects: 100% (22228/22228), 13.18 MiB | 47.01 MiB/s, done.\n",
      "Resolving deltas: 100% (15949/15949), done.\n"
     ]
    }
   ],
   "source": [
    "# Download GitHub code to local machine\n",
    "\n",
    "!git clone https://github.com/huggingface/pytorch-transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Modify scripts for Amazon SageMaker use\n",
    "\n",
    "To avoid editing the scripts inside the git folder, we copied the relevant python scripts from the folder ./pytorch-transformers/examples/ to ./train_scripts/. \n",
    "\n",
    "We made minimal changes to run_glue.py and run_squad.py to make them work with the Amazon SageMaker PyTorch framework. The changes can be found by checking the comments for `'for SageMaker use'`. These changes are largely around the way to pass arugments to the python script. In Amazon SageMaker, the easiest way to pass input arguments is as hyperparameters passed to the training job. Here are some examples of the changes made to the script:\n",
    "\n",
    "The original run_glue.py treats argument `do_train` as a boolean, to trigger model training:\n",
    "```Python\n",
    "parser.add_argument(\"--do_train\", action='store_true', help=\"Whether to run training.\")\n",
    "```\n",
    "\n",
    "We've modified the `do_train` argument to accept string inputs:\n",
    "```Python\n",
    "parser.add_argument(\"--do_train\", type=str2bool, nargs='?', const=True, default=False, help=\"Whether to run training.\")\n",
    "```\n",
    "\n",
    "with the function `str2bool()` defined in this way:\n",
    "\n",
    "```Python\n",
    "def str2bool(v):\n",
    "    if isinstance(v, bool):\n",
    "        return v\n",
    "    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n",
    "        return True\n",
    "    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n",
    "        return False\n",
    "    else:\n",
    "        raise argparse.ArgumentTypeError('Boolean value expected.')`\n",
    "```\n",
    "        \n",
    "We do this because it is not possible to pass boolean arguments into the Amazon SageMaker training job implicitly, as the orginal format was expecting; instead, we must pass an explicit value along with the `do_train` parameter. Similar changes applied the the `run_squad.py` script as well. We also made a minor change in `utils_glue.py` to allow using Python 3 to read data. Another change in the script is to print out the model evaluaton results into the CloudWatch history."
   ]
  },
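  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of this change can be checked locally. The sketch below is a minimal, standalone illustration and not the modified training script itself: it assumes that hyperparameters reach the script as `--name value` pairs of strings, and the argument list passed to `parse_args` is a hypothetical stand-in for that command line.\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "\n",
    "\n",
    "def str2bool(v):\n",
    "    # Accept real booleans as well as common string spellings.\n",
    "    if isinstance(v, bool):\n",
    "        return v\n",
    "    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n",
    "        return True\n",
    "    if v.lower() in ('no', 'false', 'f', 'n', '0'):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError('Boolean value expected.')\n",
    "\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('--do_train', type=str2bool, nargs='?', const=True, default=False)\n",
    "parser.add_argument('--do_eval', type=str2bool, nargs='?', const=True, default=False)\n",
    "\n",
    "# Hypothetical command line: every value arrives as a string.\n",
    "args = parser.parse_args(['--do_train', 'True', '--do_eval', 'False'])\n",
    "print(args.do_train, args.do_eval)\n",
    "\n",
    "# A bare flag is still accepted because of nargs='?' and const=True.\n",
    "bare = parser.parse_args(['--do_train'])\n",
    "print(bare.do_train, bare.do_eval)\n",
    "```\n",
    "\n",
    "With `action='store_true'`, the first call would fail because `True` would be treated as an unexpected extra argument; with `str2bool`, both the explicit string and the bare flag are parsed into Python booleans."
   ]
  },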
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create requirements.txt for installing dependent packages in PyTorch container"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to create a `requirements.txt` file in the same directory (./train_scripts/) as the training scripts. The requirements.txt file should include packages required by the training script that are not pre-installed by default in the Amazon SageMaker PyTorch container. We will need to install 3 pacakges for this demo:\n",
    "\n",
    "*pytorch_transformers* <br>\n",
    "*tensorboardX* <br>\n",
    "*scikit-learn*\n",
    "\n",
    "A `requirements.txt` file is text file that contains a list of items that are installed via pip. When we launch training jobs, the Amazon SageMaker container automatically looks for a `requirements.txt` file in the script source folder and uses `pip install` to install all packages listed in that file. "
   ]
  },
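  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the file does not exist yet, it can be written from the notebook. The following is a minimal sketch, assuming the training scripts live in `./train_scripts/` as described above and that no version pins are needed; the variable names are illustrative, and the code overwrites any existing `requirements.txt` at that path.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "# Packages named in this notebook that are not pre-installed in the container.\n",
    "packages = ['pytorch_transformers', 'tensorboardX', 'scikit-learn']\n",
    "\n",
    "# Assumed location: the same folder that holds the training scripts.\n",
    "requirements_path = Path('train_scripts') / 'requirements.txt'\n",
    "requirements_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Write one package name per line.\n",
    "with open(requirements_path, 'w') as f:\n",
    "    for package in packages:\n",
    "        print(package, file=f)\n",
    "\n",
    "print(requirements_path.read_text())\n",
    "```\n",
    "\n",
    "Pinning versions (for example with `==`) is a common way to keep training jobs reproducible, but the right versions depend on the container and are not specified here."
   ]
  },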
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Enviornment set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import boto3\n",
    "import sagemaker\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from time import gmtime, strftime \n",
    "from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner\n",
    "from sagemaker.pytorch import PyTorch\n",
    "\n",
    "role = sagemaker.get_execution_role() # we are using the notebook instance role for training in this example\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "bucket = sagemaker_session.default_bucket() # you can specify a bucket name here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example 1: fine tune MRPC dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload data to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input spec (in this case, just an S3 path): s3://sagemaker-us-east-1-835319576252/sagemaker/pytorch-transfomers/MRPC\n"
     ]
    }
   ],
   "source": [
    "task_name = 'MRPC'\n",
    "s3_prefix = 'sagemaker/pytorch-transfomers/' + task_name\n",
    "\n",
    "# data path in SageMaker notebook instance. Here we use the glue data MRPC for model fine tuning\n",
    "data_dir = os.path.join(os.path.join(os.getcwd(), 'glue_data'), task_name)\n",
    "\n",
    "# upload data to S3\n",
    "inputs_glue = sagemaker_session.upload_data(path=data_dir, bucket=bucket, key_prefix=s3_prefix)\n",
    "print('input spec (in this case, just an S3 path): {}'.format(inputs_glue))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data path for the SageMaker PyTorch container. We don't need to create an own container. \n",
    "container_data_dir = '/opt/ml/input/data/training'\n",
    "container_model_dir = '/opt/ml/model'\n",
    "\n",
    "# input arguments for the training script and initial values for some hyperparameters\n",
    "parameters = {\n",
    "    'model_type': 'bert',\n",
    "    'model_name_or_path' : 'bert-base-uncased',\n",
    "    'task_name': task_name,\n",
    "    'data_dir': container_data_dir,\n",
    "    'output_dir': container_model_dir,\n",
    "    'num_train_epochs': 1,\n",
    "    'per_gpu_train_batch_size': 64,\n",
    "    'per_gpu_eval_batch_size': 64,\n",
    "    'save_steps': 150,\n",
    "    'logging_steps': 150,\n",
    "    'do_train': True,\n",
    "    'do_eval': True,\n",
    "    'do_lower_case': True\n",
    "    # you can add more input arguments here\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Amazon SageMaker PyTorch framework\n",
    "\n",
    "train_instance_type = 'ml.p3.2xlarge'\n",
    "\n",
    "glue_estimator = PyTorch(entry_point='run_glue.py',\n",
    "                    source_dir = './train_scripts/', # the local directory stores all relevant scripts for modeling\n",
    "                    hyperparameters=parameters,\n",
    "                    role=role,\n",
    "                    framework_version='1.1.0',\n",
    "                    train_instance_count=1,\n",
    "                    train_instance_type=train_instance_type\n",
    "                   )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'s3://sagemaker-us-east-1-835319576252/sagemaker/pytorch-transfomers/MRPC'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check input data's s3 path\n",
    "inputs_glue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-03-18 17:43:50 Starting - Starting the training job...\n",
      "2020-03-18 17:43:52 Starting - Launching requested ML instances......\n",
      "2020-03-18 17:44:56 Starting - Preparing the instances for training.."
     ]
    }
   ],
   "source": [
    "# launch model training job\n",
    "glue_estimator.fit({'training': inputs_glue})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic model tuning - hyperparameter optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SageMaker uses the training job CloudWatch logs to extract metrics for hyperparameter optimization, processing the logs with a simple regular expression.\n",
    "\n",
    "For example, the `glue_estimator` training log has this printout for the model evaluation results:\n",
    "\n",
    "*Evaluation result =  {'acc_': 0.8455882352941176, 'f1_': 0.8941176470588236, 'acc_and_f1_': 0.8698529411764706}*\n",
    "\n",
    "Here, we want to use F1 score as the optimization metric."
   ]
  },
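  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching a tuning job, it can help to check locally that a pattern picks out the intended value. The sketch below applies a regular expression of the same form as the one used for the `f1_score` metric in this notebook to the sample log line above, using Python's `re` module. SageMaker's own log matching may differ in detail, so treat this as a sanity check only.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Pattern with one capture group for the numeric value after 'f1_'.\n",
    "f1_regex = r\"'f1_': ([0-9\\.]+)\"\n",
    "\n",
    "# Sample line copied from the training log shown above.\n",
    "log_line = \"Evaluation result =  {'acc_': 0.8455882352941176, 'f1_': 0.8941176470588236, 'acc_and_f1_': 0.8698529411764706}\"\n",
    "\n",
    "match = re.search(f1_regex, log_line)\n",
    "f1_score = float(match.group(1))\n",
    "print(f1_score)\n",
    "```\n",
    "\n",
    "The opening quote in the pattern matters: it keeps the expression from matching the `'acc_and_f1_'` entry, where `f1_` is preceded by an underscore instead of a quote."
   ]
  },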
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 1: define optimization metric\n",
    "\n",
    "metric_definitions = [{'Name': 'f1_score',\n",
    "                       'Regex': '\\'f1_\\': ([0-9\\\\.]+)'}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import boto3\n",
    "import sagemaker\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from time import gmtime, strftime \n",
    "from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner\n",
    "from sagemaker.pytorch import PyTorch\n",
    "\n",
    "# step 2: define the hyperparameter range. Here we only tune the learning rate. \n",
    "\n",
    "hyperparameter_ranges = {\n",
    "        'learning_rate': ContinuousParameter(5e-06, 5e-04, scaling_type=\"Logarithmic\")       \n",
    "    }\n",
    "\n",
    "objective_metric_name = 'f1_score'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 3: launch the hyperparameter tuning job\n",
    "\n",
    "tuner = HyperparameterTuner(glue_estimator,\n",
    "                            objective_metric_name,\n",
    "                            hyperparameter_ranges,\n",
    "                            metric_definitions,\n",
    "                            strategy = 'Bayesian',\n",
    "                            objective_type = 'Maximize',\n",
    "                            max_jobs = 5,\n",
    "                            max_parallel_jobs = 5,\n",
    "                            early_stopping_type = 'Auto')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we can track the tuning job progress in the SageMaker console by the tuning_job_name\n",
    "glue_tuning_job_name = \"pt-bert-mrpc-bs-{}\".format(strftime(\"%d-%H-%M-%S\", gmtime())) \n",
    "\n",
    "# launch model tuning job\n",
    "tuner.fit({'training': inputs_glue}, job_name = glue_tuning_job_name)\n",
    "tuner.wait()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optional: check hyperparameter tuning results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "matplotlib.rc('xtick', labelsize=12) \n",
    "matplotlib.rc('ytick', labelsize=12) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner_metrics = sagemaker.HyperparameterTuningJobAnalytics(glue_tuning_job_name)\n",
    "hpo_report = tuner_metrics.dataframe().sort_values(['FinalObjectiveValue'], ascending=False)\n",
    "\n",
    "hpo_report['job_id'] = len(hpo_report) - hpo_report.index\n",
    "hpo_report.sort_values(by='job_id', inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the value of the best learning rate is extracted from the 'Hyperparameter tuning jobs' console\n",
    "\n",
    "best_lr = 6.470088521571402e-05 # update this value"
   ]
  },
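  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to copying the value from the console, the best learning rate can likely be read from the tuning report itself. The sketch below is a minimal illustration on mock data, assuming the report has the `learning_rate` and `FinalObjectiveValue` columns used in the plots in this notebook; the values and the names `mock_report` and `best_lr_from_report` are made up for the example.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Mock tuning results; real values would come from the tuning job report.\n",
    "mock_report = pd.DataFrame({\n",
    "    'learning_rate': [1e-05, 6e-05, 3e-04],\n",
    "    'FinalObjectiveValue': [0.85, 0.89, 0.81],\n",
    "})\n",
    "\n",
    "# Row label of the highest objective value, then its learning rate.\n",
    "best_row = mock_report['FinalObjectiveValue'].idxmax()\n",
    "best_lr_from_report = float(mock_report.loc[best_row, 'learning_rate'])\n",
    "print(best_lr_from_report)\n",
    "```\n",
    "\n",
    "Applied to `hpo_report`, the same two lines would avoid a manual update of `best_lr` after each tuning run, provided the column holds numeric values."
   ]
  },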
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(6,4))\n",
    "x = hpo_report['learning_rate']\n",
    "y = hpo_report['FinalObjectiveValue']\n",
    "plt.scatter(x, y, alpha=0.8)\n",
    "\n",
    "line_x = [best_lr, best_lr]\n",
    "line_y = [0, 1]\n",
    "plt.plot(line_x, line_y, linestyle='--', linewidth=1, color='orange')\n",
    "\n",
    "plt.xlim(5e-6, 6e-4)\n",
    "plt.xscale('log')\n",
    "plt.ylim(0.75, 0.95)\n",
    "plt.xlabel('Learning rate', fontsize=14)\n",
    "plt.ylabel('F1 score', fontsize=14)\n",
    "plt.title('MRPC: F1 score curve over learning rate', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/MRPC_F1_learning_rate.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "x = hpo_report['job_id']\n",
    "y = hpo_report['FinalObjectiveValue']\n",
    "plt.plot(x, y, alpha=0.8, linestyle='-', marker='o')\n",
    "plt.ylim(0.75, 0.95)\n",
    "plt.ylabel('F1 score', fontsize=14)\n",
    "plt.xlabel('Training job order index', fontsize=14)\n",
    "plt.title('MRPC: F1 score history', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/MRPC_F1_job_order.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "x = len(hpo_report) - hpo_report.index\n",
    "y = hpo_report['learning_rate']\n",
    "\n",
    "line_y = [best_lr, best_lr]\n",
    "line_x = [0, 40]\n",
    "plt.plot(x, y, alpha=0.8, linestyle='-', marker='o')\n",
    "plt.plot(line_x, line_y, linestyle='--', linewidth=1, color='orange')\n",
    "\n",
    "plt.xlim(0, 31)\n",
    "plt.ylim(5e-6, 6e-4)\n",
    "plt.yscale('log')\n",
    "plt.ylabel('Learning rate', fontsize=14)\n",
    "plt.xlabel('Training job order index', fontsize=14)\n",
    "plt.title('MRPC: learning rate search history', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/MRPC_lr_job_order.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example 2: fine tune SQuAD dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download SQuAD dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -P squad_data/\n",
    "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -P squad_data/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload data to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name = 'squad'\n",
    "s3_prefix = 'sagemaker/pytorch-transfomers/' + task_name\n",
    "\n",
    "# data path in SageMaker notebook instance. Here we use the glue data MRPC for model fine tuning\n",
    "data_dir = os.path.join(os.getcwd(), 'squad_data')\n",
    "\n",
    "# upload data to S3\n",
    "inputs_squad = sagemaker_session.upload_data(path=data_dir, bucket=bucket, key_prefix=s3_prefix)\n",
    "\n",
    "print('input spec (in this case, just an S3 path): {}'.format(inputs_squad))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data path for the SageMaker PyTorch container. We don't need to create an own container. \n",
    "container_data_dir = '/opt/ml/input/data/training'\n",
    "container_model_dir = '/opt/ml/model'\n",
    "\n",
    "# input arguments for the training script and initial values for some hyperparameters\n",
    "parameters = {\n",
    "    'model_type': 'bert',\n",
    "    'model_name_or_path' : 'bert-base-uncased',\n",
    "    'train_file': container_data_dir+'/train-v1.1.json', # specify dataset version\n",
    "    'predict_file': container_data_dir+'/dev-v1.1.json',\n",
    "    'output_dir': container_model_dir,\n",
    "    'learning_rate': 5e-5,\n",
    "    'per_gpu_train_batch_size': 16,\n",
    "    'per_gpu_eval_batch_size': 16,\n",
    "    'num_train_epochs': 1,\n",
    "    'max_seq_length': 384,\n",
    "    'doc_stride': 128,\n",
    "    'save_steps': 10000,\n",
    "    'logging_steps': 10000,\n",
    "    'do_train': True,\n",
    "    'do_eval': True,\n",
    "    'do_lower_case': True,\n",
    "    'version_2_with_negative': False # False is for the 1.1 dataset. True is for SQuAD 2.0. \n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Amazon SageMaker PyTorch framework\n",
    "\n",
    "train_instance_type = 'ml.p3.2xlarge'\n",
    "\n",
    "squad_estimator = PyTorch(entry_point='run_squad.py',\n",
    "                    source_dir = './train_scripts/',  # the local directory stores all relevant scripts for modeling\n",
    "                    hyperparameters=parameters,\n",
    "                    role=role,\n",
    "                    framework_version='1.1.0',\n",
    "                    train_instance_count=1,\n",
    "                    train_instance_type=train_instance_type\n",
    "                   )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check input data's s3 path\n",
    "inputs_squad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# launch model training job\n",
    "squad_estimator.fit({'training': inputs_squad})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic model tuning - hyperparameter optimization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SageMaker uses the training job CloudWatch logs to extract metrics for hyperparameter optimization, processing the logs with a simple regular expression.\n",
    "\n",
    "For example, the `squad_estimator` training log has this printout for the model evaluation results:\n",
    "\n",
    "*Evaluation result ={'exact': 80.71901608325449, 'f1': 88.0493020797288, \n",
    "                     'total': 10570, 'HasAns_exact': 80.71901608325449, \n",
    "                     'HasAns_f1': 88.0493020797288, 'HasAns_total': 10570}*\n",
    "\n",
    "Here, we want to use F1 score as the optimization metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 1: define optimization metric\n",
    "\n",
    "metric_definitions = [{'Name': 'f1_score',\n",
    "                       'Regex': '\\'f1\\': ([0-9\\\\.]+)'}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 2: define the hyperparameter range. Here we only tune the learning rate. \n",
    "\n",
    "hyperparameter_ranges = {\n",
    "        'learning_rate': ContinuousParameter(1e-05, 5e-04, scaling_type=\"Logarithmic\")       \n",
    "    }\n",
    "\n",
    "objective_metric_name = 'f1_score'\n",
    "objective_type = 'Maximize'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step 3: launch the hyperparameter tuning job\n",
    "\n",
    "tuner = HyperparameterTuner(squad_estimator,\n",
    "                            objective_metric_name,\n",
    "                            hyperparameter_ranges,\n",
    "                            metric_definitions,\n",
    "                            strategy = 'Bayesian',\n",
    "                            objective_type = 'Maximize',\n",
    "                            max_jobs = 5,\n",
    "                            max_parallel_jobs = 5,\n",
    "                            early_stopping_type = 'Auto')\n",
    "\n",
    "# we can track the tuning job progress in the SageMaker console by the tuning_job_name\n",
    "squad_tuning_job_name = \"pt-squad1-bs-{}\".format(strftime(\"%d-%H-%M-%S\", gmtime()))\n",
    "\n",
    "# launch model tuning job\n",
    "tuner.fit({'training': inputs_squad}, job_name=squad_tuning_job_name)\n",
    "tuner.wait()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squad_tuning_job_name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optional: check hyperparameter tunning results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner_metrics = sagemaker.HyperparameterTuningJobAnalytics(squad_tuning_job_name)\n",
    "tuner_metrics.dataframe().sort_values(['FinalObjectiveValue'], ascending=False).head(10)\n",
    "\n",
    "hpo_report = tuner_metrics.dataframe().sort_values(['FinalObjectiveValue'], ascending=False)\n",
    "hpo_report['job_id'] = len(hpo_report) - hpo_report.index\n",
    "hpo_report.sort_values(by='job_id', inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the value of the best learning rate is extracted from the 'Hyperparameter tuning jobs' console\n",
    "\n",
    "best_lr = 5.7330400829294637e-05 # update this value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(6,4))\n",
    "x = hpo_report['learning_rate']\n",
    "y = hpo_report['FinalObjectiveValue']\n",
    "plt.scatter(x, y, alpha=0.8)\n",
    "\n",
    "line_x = [best_lr, best_lr]\n",
    "line_y = [0, 1]\n",
    "plt.plot(line_x, line_y, linestyle='--', linewidth=1, color='orange')\n",
    "\n",
    "plt.xlim(5e-6, 6e-4)\n",
    "plt.xscale('log')\n",
    "plt.ylim(0.75, 0.95)\n",
    "plt.xlabel('Learning rate', fontsize=14)\n",
    "plt.ylabel('F1 score', fontsize=14)\n",
    "plt.title('MRPC: F1 score curve over learning rate', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/SQUAD_F1_learning_rate.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "x = hpo_report['job_id']\n",
    "y = hpo_report['FinalObjectiveValue']\n",
    "plt.plot(x, y, alpha=0.8, linestyle='-', marker='o')\n",
    "plt.ylim(0.75, 0.95)\n",
    "plt.ylabel('F1 score', fontsize=14)\n",
    "plt.xlabel('Training job order index', fontsize=14)\n",
    "plt.title('MRPC: F1 score history', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/SQUAD_F1_job_order.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "x = len(hpo_report) - hpo_report.index\n",
    "y = hpo_report['learning_rate']\n",
    "\n",
    "line_y = [best_lr, best_lr]\n",
    "line_x = [0, 40]\n",
    "plt.plot(x, y, alpha=0.8, linestyle='-', marker='o')\n",
    "plt.plot(line_x, line_y, linestyle='--', linewidth=1, color='orange')\n",
    "\n",
    "plt.xlim(0, 31)\n",
    "plt.ylim(5e-6, 6e-4)\n",
    "plt.yscale('log')\n",
    "plt.ylabel('Learning rate', fontsize=14)\n",
    "plt.xlabel('Training job order index', fontsize=14)\n",
    "plt.title('MRPC: learning rate search history', fontsize=14)\n",
    "plt.grid()\n",
    "#plt.savefig('figures/SQUAD_lr_job_order.png', dpi=200, transparent=True, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
